{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gr7_vyXBa2st"
      },
      "source": [
        "# Using SmoothQuant on OPT large models\n",
        "This notebook is a companion of chapter 8 of the \"Domain Specific LLMs in Action\" book, author Guglielmo Iozzia, [Manning Publications](https://www.manning.com/), 2024.  \n",
        "The code in this notebook shows that, for LLMs with 6 billion or more parameters, systematic outliers in a model's activations degrade accuracy after quantization, and that the [SmoothQuant](https://github.com/mit-han-lab/smoothquant) technique mitigates that risk. While the code refers to Meta AI's [OPT 6.7B](https://huggingface.co/facebook/opt-6.7b) model, the same approach applies to other models. It requires hardware acceleration (a GPU) to run.  \n",
        "More details about the code can be found in the related book's chapter."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KNIFbSY0jLoB"
      },
      "source": [
        "Force an upgrade of the Hugging Face Datasets library to the latest version. After the upgrade completes, restart the runtime before executing the remaining cells."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VsrRohRGjaRo"
      },
      "outputs": [],
      "source": [
        "!pip install --force-reinstall datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZbPIRBTUcVxD"
      },
      "source": [
        "Install SmoothQuant from source."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3DlrH551CDuv"
      },
      "outputs": [],
      "source": [
        "!pip install git+https://github.com/mit-han-lab/smoothquant.git"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nThhlX98chpV"
      },
      "source": [
        "Import the required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CD1M8UkYCVsb"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, OPTForCausalLM\n",
        "from transformers import GPT2Tokenizer\n",
        "from smoothquant.smooth import smooth_lm\n",
        "from smoothquant.fake_quant import W8A8Linear"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LbSMBcy9cmXE"
      },
      "source": [
        "Define a custom function to quantize a model (weights and activations) to INT8 precision."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CigNkKliCbVS"
      },
      "outputs": [],
      "source": [
        "def quantize_model(model, weight_quant='per_tensor', act_quant='per_tensor', quantize_bmm_input=True):\n",
        "    for name, m in model.model.named_modules():\n",
        "        if isinstance(m, OPTDecoderLayer):\n",
        "            m.fc1 = W8A8Linear.from_float(m.fc1, weight_quant=weight_quant,\n",
        "                                          act_quant=act_quant)\n",
        "            m.fc2 = W8A8Linear.from_float(m.fc2, weight_quant=weight_quant,\n",
        "                                          act_quant=act_quant)\n",
        "        elif isinstance(m, OPTAttention):\n",
        "            m.q_proj = W8A8Linear.from_float(\n",
        "                m.q_proj, weight_quant=weight_quant, act_quant=act_quant,\n",
        "                quantize_output=quantize_bmm_input)\n",
        "            m.k_proj = W8A8Linear.from_float(\n",
        "                m.k_proj, weight_quant=weight_quant, act_quant=act_quant,\n",
        "                quantize_output=quantize_bmm_input)\n",
        "            m.v_proj = W8A8Linear.from_float(\n",
        "                m.v_proj, weight_quant=weight_quant, act_quant=act_quant,\n",
        "                quantize_output=quantize_bmm_input)\n",
        "            m.out_proj = W8A8Linear.from_float(m.out_proj,\n",
        "                                               weight_quant=weight_quant, act_quant=act_quant)\n",
        "    return model"
      ]
    },
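    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `per_tensor` setting used as the default above can be pictured with the symmetric uniform quantization formula given in the SmoothQuant paper. This is a sketch for intuition only; the `W8A8Linear` implementation may differ in details such as clamping or rounding mode.\n",
        "\n",
        "$$\\bar{X}^{\\mathrm{INT8}} = \\left\\lceil \\frac{X^{\\mathrm{FP16}}}{\\Delta} \\right\\rfloor, \\qquad \\Delta = \\frac{\\max(\\lvert X \\rvert)}{2^{N-1} - 1}$$\n",
        "\n",
        "Here $\\lceil \\cdot \\rfloor$ denotes rounding to the nearest integer and $N = 8$, so $2^{N-1} - 1 = 127$. With per-tensor granularity a single step size $\\Delta$ is shared by the whole tensor, so one large outlier inflates $\\Delta$ and leaves few effective quantization levels for the remaining, much smaller values."
      ]
    },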
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kjc7sVIIc5V8"
      },
      "source": [
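        "Implement a class to evaluate an LLM on a given test dataset. The metric is last-token prediction accuracy: for each sample, the top-1 prediction made at the second-to-last position is compared with the actual last token."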
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n_B5F1w-CiUj"
      },
      "outputs": [],
      "source": [
        "class Evaluator:\n",
        "    def __init__(self, dataset, tokenizer, device):\n",
        "        self.dataset = dataset\n",
        "        self.tokenizer = tokenizer\n",
        "        self.device = device\n",
        "\n",
        "        def tokenize_function(examples):\n",
        "            example = self.tokenizer(examples['text'])\n",
        "            return example\n",
        "\n",
        "        self.dataset = self.dataset.map(tokenize_function, batched=True)\n",
        "        self.dataset.set_format(type='torch', columns=['input_ids'])\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def evaluate(self, model):\n",
        "        model.eval()\n",
        "        total, hit = 0, 0\n",
        "        for batch in self.dataset:\n",
        "            input_ids = batch['input_ids'].to(self.device).unsqueeze(0)\n",
        "            label = input_ids[:, -1]\n",
        "            outputs = model(input_ids)\n",
        "            last_token_logits = outputs.logits[:, -2, :]\n",
        "            pred = last_token_logits.argmax(dim=-1)\n",
        "            total += label.size(0)\n",
        "            hit += (pred == label).sum().item()\n",
        "        acc = hit / total\n",
        "        return acc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bl47DI6xdIUo"
      },
      "source": [
        "Download a subset (1000 samples in this case) of the LAMBADA dataset and the Meta AI OPT 6.7B model's tokenizer from the Hugging Face Hub, then create an instance of the Evaluator class using them. The evaluator moves each sample to the GPU at evaluation time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IB2cUFLoCpiz"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "from transformers import GPT2Tokenizer\n",
        "import torch\n",
        "\n",
        "model_id = 'facebook/opt-6.7b'\n",
        "tokenizer = GPT2Tokenizer.from_pretrained(model_id)\n",
        "dataset = load_dataset('cimec/lambada', split='validation[:1000]')\n",
        "evaluator = Evaluator(dataset, tokenizer, 'cuda')\n",
        "\n",
        "print(\"Dataset loaded and Evaluator initialized successfully.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F_J8cAtLEGk6"
      },
      "source": [
        "#### FP16 Model Accuracy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wEjDztgMebk5"
      },
      "source": [
        "Download the Meta AI OPT 6.7B model in FP16 from the HF's Hub."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tLKSBJGZEDdk"
      },
      "outputs": [],
      "source": [
        "model_fp16 = OPTForCausalLM.from_pretrained(model_id,\n",
        "                                            torch_dtype=torch.float16,\n",
        "                                            device_map='auto',\n",
        "                                            offload_folder='.')\n",
        "model_fp16.eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QV0YtmCNejz_"
      },
      "source": [
        "Evaluate the model on the 1000 samples from the LAMBADA dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XcucIzD8EKiN"
      },
      "outputs": [],
      "source": [
        "acc_fp16 = evaluator.evaluate(model_fp16)\n",
        "print(f'Original model (fp16) accuracy: {acc_fp16}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ruIY-iq2EPpf"
      },
      "source": [
        "#### Naive W8A8 Quantized Model Accuracy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PSUBP9a2eseJ"
      },
      "source": [
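        "Quantize the weights and activations of the vanilla model (no SmoothQuant). Note that `quantize_model` modifies the model in place, so `model_w8a8` and `model_fp16` refer to the same object afterwards."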
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Iw9nJmg-EM_1"
      },
      "outputs": [],
      "source": [
        "model_w8a8 = quantize_model(model_fp16)\n",
        "print(model_w8a8)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9m1wMMg-e0Tl"
      },
      "source": [
        "Evaluate the quantized model on the 1000 samples from the LAMBADA dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hEuo8ij1EUBX"
      },
      "outputs": [],
      "source": [
        "acc_w8a8 = evaluator.evaluate(model_w8a8)\n",
        "print(f'Naive W8A8 quantized model accuracy: {acc_w8a8}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nwxMG8nTEaq0"
      },
      "source": [
        "#### SmoothQuant W8A8 Quantized Model Accuracy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dWaNpVrDfQ3v"
      },
      "source": [
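        "**To save time and free GPU memory to evaluate the model after applying SmoothQuant, a runtime restart is recommended at this point, before proceeding further.** After the restart, re-run the cells that import the dependencies, define `quantize_model` and `Evaluator`, and create the tokenizer and evaluator."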
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "To6V3XW-e6FT"
      },
      "source": [
        "Download the activation scales for this specific model from the Hugging Face Hub (they are required to apply SmoothQuant)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lv2VEOhLLRux"
      },
      "outputs": [],
      "source": [
        "!mkdir -p ./act_scales\n",
        "%cd act_scales\n",
        "!wget https://huggingface.co/mit-han-lab/smoothquant-scales/resolve/main/opt-6.7b.pt\n",
        "%cd .."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAqSspGQfy5l"
      },
      "source": [
        "Reload the vanilla FP16 model, apply SmoothQuant, and then quantize its weights and activations to INT8."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J6jzxr95gsUQ"
      },
      "outputs": [],
      "source": [
        "model_fp16 = OPTForCausalLM.from_pretrained(model_id,\n",
        "                                            torch_dtype=torch.float16,\n",
        "                                            device_map='auto',\n",
        "                                            offload_folder='.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H41rVIHjEbk3"
      },
      "outputs": [],
      "source": [
        "act_scales = torch.load('./act_scales/opt-6.7b.pt')\n",
        "smooth_lm(model_fp16, act_scales, 0.5)\n",
        "model_smoothquant_w8a8 = quantize_model(model_fp16)\n",
        "print(model_smoothquant_w8a8)"
      ]
    },
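    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch of what the smoothing step does, following the SmoothQuant paper and assuming the third argument of `smooth_lm` (0.5 here) is the migration strength $\\alpha$: each input channel $j$ of a linear layer is divided by a factor $s_j$ on the activation side and multiplied by it on the weight side, so the layer output is mathematically unchanged.\n",
        "\n",
        "$$Y = \\left(X \\, \\mathrm{diag}(s)^{-1}\\right) \\left(\\mathrm{diag}(s) \\, W\\right) = \\hat{X} \\hat{W}, \\qquad s_j = \\frac{\\max(\\lvert X_j \\rvert)^{\\alpha}}{\\max(\\lvert W_j \\rvert)^{1 - \\alpha}}$$\n",
        "\n",
        "With $\\alpha = 0.5$ the quantization difficulty is split evenly between activations and weights. The downloaded activation scales are assumed to supply the per-channel $\\max(\\lvert X_j \\rvert)$ statistics."
      ]
    },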
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Kr5p55sgCHy"
      },
      "source": [
        "Evaluate the smooth quantized model on the 1000 samples from the LAMBADA dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_MjNqiPdiG_J"
      },
      "outputs": [],
      "source": [
        "model_smoothquant_w8a8.eval()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k_qv-WtGEdp-"
      },
      "outputs": [],
      "source": [
        "acc_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)\n",
        "print(f'SmoothQuant W8A8 quantized model accuracy: {acc_smoothquant_w8a8}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPUT6bVQgInt"
      },
      "source": [
        "The accuracy of the vanilla FP16 model and that of its SmoothQuant-quantized version should be comparable, while the naively quantized model should show a significant drop (up to 40%)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
